from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import aiohttp
from collections import defaultdict
from datetime import datetime

class OpenBookQAReasoner(BaseReasoner):
    """Answer OpenBookQA multiple-choice questions by voting over three strategies.

    For every problem, :meth:`execute_workflow` runs three strategies concurrently
    and combines their answers with :meth:`_vote_on_answers`:

    * ``ho1`` -- HoT v1: extract conditions, let the LLM propose solution methods,
      optionally split the best method into cases, solve each leaf node and
      aggregate the leaf answers.
    * ``ho2`` -- HoT v2 (ablation): the same pipeline with one fixed
      "direct reasoning" method instead of LLM-proposed methods.
    * ``cot`` -- a single chain-of-thought prompt.

    Answers are option letters ``"A"``-``"D"``; ``"X"`` is the "no answer"
    sentinel. ``self.config``, ``self.llm``, ``self.nodes``, ``self.logs``,
    ``self.temp_list``, ``_create_node`` and ``_log_step`` are provided by
    ``BaseReasoner`` (defined elsewhere).
    """

    # Option letters assigned, in order, to problem["choices"]["text"][0..3].
    _OPTION_LETTERS = ("A", "B", "C", "D")
    # Sentinel answer used when no option letter could be determined.
    _NO_ANSWER = "X"
    # Tie-break order for _vote_on_answers (the first listed method wins a tie).
    _VOTE_PRIORITY = ("ho1", "ho2", "cot")

    # Persona preamble shared by every HoT prompt. The trailing double spaces are
    # part of the prompt text (markdown-style line breaks), written out
    # explicitly so editors that strip trailing whitespace cannot change them.
    _EXPERT_PREAMBLE = (
        "You are a top expert in formal logic, critical thinking, and argument analysis.  \n"
        "You are precise, rational, and skeptical.  \n"
        "You always examine each statement carefully, identify premises and conclusions, "
        "and evaluate logical validity step by step.  \n"
        "You avoid unwarranted assumptions, think in terms of logical consequences, "
        "and eliminate invalid options with sound reasoning.  \n"
        "You aim to reach conclusions based only on evidence and logic.  \n"
        "You THINK SLOWLY, CAREFULLY, AND LOGICALLY.\n"
    )

    # Answer-extraction patterns, tried in order. Within a pattern the LAST match
    # wins, because models restate their final choice at the end of a response.
    # A bare capital letter is deliberately NOT accepted: "A" is also an English
    # article, and matching it would silently bias extraction towards option A.
    _ANSWER_PATTERNS = (
        # \boxed{B}, \boxed{[B]}, \boxed{(B)}
        re.compile(r"\\boxed\{\s*[\[(]?\s*([A-D])\s*[\])]?\s*\}"),
        # "Final Answer: B", "The correct answer is [B]", "answer is: **B**"
        re.compile(r"(?i:answer)\s*(?:is\s*)?:?\s*\**\s*[\[(]?\s*([A-D])\b"),
        # Last resort: an explicitly marked option, e.g. "option C", "(C)", "[C]".
        re.compile(r"(?:\b(?i:option|choice)\s+|[\[(])([A-D])\b"),
    )

    def __init__(self):
        super().__init__("OpenBookQA")
        self.config.dataset_path = "datasets/OpenBookQA.json"

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load OpenBookQA problems ``data[start_idx:end_idx]`` from the JSON dataset.

        Loading is best-effort: any error is printed and an empty list returned.
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def execute_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Run HoT v1, HoT v2 and CoT concurrently and vote on their answers.

        Returns:
            ``{"status": "success", "final_answer", "method_results"}`` or, if an
            unexpected error escapes, ``{"status": "error", "message", "logs"}``.
        """
        try:
            # The three strategies are independent, so their LLM calls overlap.
            ho1_result, ho2_result, cot_result = await asyncio.gather(
                self._execute_ho1_workflow(problem),
                self._execute_ho2_workflow(problem),
                self._execute_cot_workflow(problem),
            )

            answers = {
                "ho1": ho1_result.get("final_answer"),
                "ho2": ho2_result.get("final_answer"),
                "cot": cot_result.get("answer"),
            }

            print("\nMethod Results:")
            print(f"HoT v1 Answer: {answers['ho1']}")
            print(f"HoT v2 Answer: {answers['ho2']}")
            print(f"CoT Answer: {answers['cot']}")

            final_answer = self._vote_on_answers(answers)
            print(f"\nFinal Voted Answer: {final_answer}")

            return {
                "status": "success",
                "final_answer": final_answer,
                "method_results": {
                    "ho1": ho1_result,
                    "ho2": ho2_result,
                    "cot": cot_result,
                },
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs,
            }

    @classmethod
    def _normalize_answer(cls, answer: Any) -> Optional[str]:
        """Return *answer* as an upper-case option letter, or None if it is not one."""
        if not isinstance(answer, str):
            return None
        letter = answer.strip().upper()
        return letter if letter in cls._OPTION_LETTERS else None

    def _vote_on_answers(self, answers: Dict[str, Any]) -> str:
        """Majority vote over the per-method answers.

        Args:
            answers: Method name -> answer. Values that are not an option letter
                (e.g. ``None`` or the ``"X"`` sentinel of a failed run) abstain
                instead of being counted as a vote.

        Returns:
            The most frequent valid letter. Ties are broken by ``_VOTE_PRIORITY``
            order, then by the iteration order of *answers*. ``"X"`` if no method
            produced a valid letter.
        """
        valid: Dict[str, str] = {}
        for method, answer in answers.items():
            letter = self._normalize_answer(answer)
            if letter is not None:
                valid[method] = letter
        if not valid:
            return self._NO_ANSWER

        counts: Dict[str, int] = defaultdict(int)
        for letter in valid.values():
            counts[letter] += 1
        top = max(counts.values())
        candidates = {letter for letter, count in counts.items() if count == top}
        if len(candidates) == 1:
            return candidates.pop()

        tie_break_order = [m for m in self._VOTE_PRIORITY if m in valid]
        tie_break_order += [m for m in valid if m not in self._VOTE_PRIORITY]
        for method in tie_break_order:
            if valid[method] in candidates:
                return valid[method]
        return self._NO_ANSWER  # unreachable: every candidate comes from `valid`

    def _build_options(self, problem: Dict[str, Any]) -> Dict[str, str]:
        """Map letters A-D to the first four entries of ``problem["choices"]["text"]``.

        Raises KeyError/IndexError for a malformed problem (callers catch it).
        """
        texts = problem["choices"]["text"]
        return {letter: texts[i] for i, letter in enumerate(self._OPTION_LETTERS)}

    @staticmethod
    def _direct_reasoning_method(score_reason: str) -> Dict[str, Any]:
        """Build the fixed "direct reasoning" method used by HoT v2 (ablation)."""
        return {
            "description": "Direct reasoning with conditions",
            "steps": [
                "Analyze question and conditions",
                "Evaluate each option against conditions",
                "Select best matching option",
            ],
            "score": 80,
            "score_reason": score_reason,
        }

    @staticmethod
    def _base_conditions(conditions: Any) -> Dict[str, List[Any]]:
        """Return fresh copies of the ``explicit``/``implicit`` condition lists.

        Missing keys become empty lists; a non-list value is wrapped in a list so
        that case conditions can always be appended.
        """
        source = conditions if isinstance(conditions, dict) else {}

        def as_list(value: Any) -> List[Any]:
            if value is None:
                return []
            return list(value) if isinstance(value, list) else [value]

        return {
            "explicit": as_list(source.get("explicit")),
            "implicit": as_list(source.get("implicit")),
        }

    def _create_case_nodes(
        self,
        parent: ReasoningNode,
        cases: List[Any],
        method: Dict[str, Any],
        question: str,
        options: Dict[str, str],
    ) -> List[str]:
        """Step 6: create one child of *parent* per classification case.

        Each child inherits *parent*'s conditions plus the case's conditions
        (keys ``explicit``/``implicit`` extend those lists; any other key is
        recorded as an ``"key: value"`` implicit condition).

        Returns:
            The node ids of the created children (empty if *cases* is empty).
        """
        created: List[str] = []
        for case in cases:
            combined = self._base_conditions(parent.conditions)
            case_conditions = case.get("conditions") if isinstance(case, dict) else None
            if isinstance(case_conditions, dict):
                for key, value in case_conditions.items():
                    if key in combined:
                        combined[key].append(value)
                    else:
                        combined["implicit"].append(f"{key}: {value}")
            elif case_conditions:
                combined["implicit"].append(case_conditions)

            node = self._create_node(
                path=parent.path + [parent.node_id],
                question=question,
                options=options,
                method=dict(method),
                conditions=combined,
                score=parent.score,
                parent_id=parent.node_id,
            )
            parent.children.append(node.node_id)
            created.append(node.node_id)
            self._log_step("step6", node.node_id, {"case": case})
        return created

    async def _solve_leaf_nodes(self, node_ids: List[str]) -> List[Dict[str, Any]]:
        """Step 7: solve each node in *node_ids*, keeping those that yielded a letter."""
        solutions = []
        for node_id in node_ids:
            solution = await self._solve_node(node_id)
            if solution:
                solutions.append(solution)
                self._log_step("step7", node_id, {"solution": solution})
        return solutions

    async def _execute_ho1_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """HoT v1: LLM-explored methods, optional case split, solve and aggregate.

        Only the leaf nodes created by THIS call are solved. ho1 and ho2 run
        concurrently and share base-class state, so solving everything in
        ``self.temp_list`` would mix the two strategies' nodes.
        """
        try:
            question = problem["question_stem"]
            options = self._build_options(problem)

            # Step 1: Create root node
            root = self._create_node(
                question=question,
                options=options,
                conditions={},
                path=[],
                method={"description": "Original problem"},
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step 2: Extract conditions
            conditions = await self._extract_conditions(question, options)
            root.conditions = conditions
            self._log_step("step2", root.node_id, {"conditions": conditions})

            # Step 3: Explore solution methods (sorted best-first)
            methods = await self._explore_solutions(question, options)
            self._log_step("step3", root.node_id, {"methods": methods})

            # Step 4: Create method nodes for the top beam_width methods
            method_nodes = []
            for method in methods[:self.config.beam_width]:
                node = self._create_node(
                    path=[root.node_id],
                    question=question,
                    options=options,
                    method=method,
                    conditions=root.conditions,
                    score=method.get("score", 0),
                    parent_id=root.node_id,
                )
                root.children.append(node.node_id)
                method_nodes.append(node)
                self._log_step("step4", node.node_id, {"method": method})
            if not method_nodes:
                raise ValueError("No solution methods available for HoT v1")

            # Step 5: Check classification for best method
            best_method_node = max(method_nodes, key=lambda n: n.score)
            classification = await self._check_classification(
                best_method_node.method["description"],
                question,
                options,
            )
            self._log_step("step5", best_method_node.node_id, {"classification": classification})

            # Step 6: Split into case nodes if needed; otherwise (or if no case
            # could be created) solve the best method node itself.
            leaf_ids: List[str] = []
            if classification.get("need_classify"):
                leaf_ids = self._create_case_nodes(
                    best_method_node,
                    classification.get("cases") or [],
                    best_method_node.method,
                    question,
                    options,
                )
            if not leaf_ids:
                leaf_ids = [best_method_node.node_id]
            self.temp_list.extend(leaf_ids)

            # Step 7: Solve this workflow's leaf nodes
            solutions = await self._solve_leaf_nodes(leaf_ids)

            # Step 8: Aggregate answers
            final_answer = await self._aggregate_answers(solutions)
            self._log_step("step8", "system", {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs,
            }

    async def _execute_ho2_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Simplified HoT workflow (version 2): a fixed direct-reasoning method.

        Like HoT v1, only the leaf nodes created by this call are solved.
        """
        try:
            question = problem["question_stem"]
            options = self._build_options(problem)

            # Step 1: Create root node
            root = self._create_node(
                question=question,
                options=options,
                conditions={},
                path=[],
                method={"description": "Original problem"},
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step 2: Extract conditions
            conditions = await self._extract_conditions(question, options)
            root.conditions = conditions
            self._log_step("step2", root.node_id, {"conditions": conditions})

            # Step 4 (ablation): a single fixed method instead of LLM exploration
            default_method = self._direct_reasoning_method("Default method for ablation study")
            method_node = self._create_node(
                path=[root.node_id],
                question=question,
                options=options,
                method=default_method,
                conditions=root.conditions,
                score=default_method["score"],
                parent_id=root.node_id,
            )
            root.children.append(method_node.node_id)
            self._log_step("step4_ablation", method_node.node_id, {"method": default_method})

            # Step 5: Check classification for the method
            classification = await self._check_classification(
                method_node.method["description"],
                question,
                options,
            )
            self._log_step("step5", method_node.node_id, {"classification": classification})

            # Step 6: Case nodes if needed; otherwise (or if no case could be
            # created) a single leaf carrying the method node's conditions.
            leaf_ids: List[str] = []
            if classification.get("need_classify"):
                leaf_ids = self._create_case_nodes(
                    method_node,
                    classification.get("cases") or [],
                    self._direct_reasoning_method("Case-specific solution"),
                    question,
                    options,
                )
            if not leaf_ids:
                leaf = self._create_node(
                    path=method_node.path + [method_node.node_id],
                    question=question,
                    options=options,
                    method=self._direct_reasoning_method("Case-specific solution"),
                    conditions=self._base_conditions(method_node.conditions),
                    score=method_node.score,
                    parent_id=method_node.node_id,
                )
                # Register the leaf under the parent named by its parent_id.
                method_node.children.append(leaf.node_id)
                leaf_ids = [leaf.node_id]
            self.temp_list.extend(leaf_ids)

            # Step 7: Solve this workflow's leaf nodes
            solutions = await self._solve_leaf_nodes(leaf_ids)

            # Step 8: Aggregate answers
            final_answer = await self._aggregate_answers(solutions)
            self._log_step("step8", "system", {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs,
            }

    async def _execute_cot_workflow(self, problem: Dict[str, Any]) -> Dict[str, Any]:
        """Chain-of-Thought workflow: one step-by-step prompt, then answer extraction.

        Returns ``{"status", "response", "answer"}`` on success, or
        ``{"status": "error", "message", "answer": None}`` on failure.
        """
        try:
            question = problem["question_stem"]
            options_text = "\n".join(
                f"{letter}. {text}" for letter, text in self._build_options(problem).items()
            )

            prompt = (
                " \n"
                f"Question: {question}\n"
                "Options:\n"
                f"{options_text}\n"
                "Let's think step by step to solve the question, give the correct answer "
                'by stating "The correct answer is [X]" where [X] is exactly one letter '
                "(A, B, C, or D)."
            )

            response = await self.llm.generate(prompt)
            answer = self._extract_answer(response)

            return {
                "status": "success",
                "response": response,
                "answer": answer,
            }

        except Exception as e:
            print(f"CoT Error: {str(e)}")
            return {
                "status": "error",
                "message": str(e),
                "answer": None,
            }

    @staticmethod
    def _parse_llm_json(response: Optional[str]) -> Any:
        """Decode an LLM reply as JSON, tolerating a surrounding Markdown code fence.

        Raises:
            json.JSONDecodeError: if the (unfenced) text is not valid JSON.
        """
        text = (response or "").strip()
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
            text = text.strip()
            # Only strip a closing fence that is actually present.
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()
        return json.loads(text)

    async def _extract_conditions(self, question: str, options: Dict[str, str]) -> Dict[str, Any]:
        """Step 2: ask the LLM for explicit/implicit conditions of the problem.

        Retries up to ``config.max_retries`` times. Failed attempts are printed,
        not raised. Returns a placeholder with empty lists if every attempt fails.
        """
        prompt = self._EXPERT_PREAMBLE + f"""Analyze this question and extract key conditions:

Question: {question}
Options:
A. {options['A']}
B. {options['B']}
C. {options['C']}
D. {options['D']}

Identify:
1. Explicit conditions (directly stated)
2. Implicit conditions (logical implications)
3. Key terms and their relationships
4. Spatial/temporal relationships if present
5. Any conditional statements

Output JSON format:
{{
    "explicit": ["list", "of", "conditions"],
    "implicit": ["list", "of", "conditions"],
    "key_terms": ["term1", "term2"],
    "notes": "Analysis summary"
}}"""

        for attempt in range(self.config.max_retries):
            # Best-effort: any LLM or parse failure triggers a retry. `except
            # Exception` (not a bare except) lets asyncio cancellation propagate.
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                data = self._parse_llm_json(response)
            except Exception as e:
                print(f"Condition extraction attempt {attempt + 1} failed: {str(e)}")
                continue
            if isinstance(data, dict):
                return data
            print(f"Condition extraction attempt {attempt + 1} failed: expected a JSON object")

        return {
            "explicit": [],
            "implicit": [],
            "key_terms": [],
            "notes": "Failed to extract conditions",
        }

    @staticmethod
    def _coerce_score(score: Any) -> Union[int, float]:
        """Return *score* as a number (numeric strings are converted).

        Raises:
            ValueError: if *score* is not numeric.
        """
        if isinstance(score, (int, float)) and not isinstance(score, bool):
            return score
        try:
            return float(score)
        except (TypeError, ValueError):
            raise ValueError(f"Invalid method score: {score!r}") from None

    async def _explore_solutions(self, question: str, options: Dict[str, str]) -> List[Dict]:
        """Step 3: ask the LLM for distinct solution methods, sorted by score (desc).

        Each method must have ``description``, ``steps`` (list), ``score``
        (numeric, or a numeric string which is converted) and ``score_reason``.
        Malformed replies are retried up to ``config.max_retries`` times.
        Returns ``[]`` if all attempts fail.
        """
        options_text = "\n".join(f"{k}. {v}" for k, v in options.items())

        prompt = self._EXPERT_PREAMBLE + f"""Generate 3 distinct solution approaches for this question:

Question: {question}
Options:
{options_text}

For each approach, provide:
- Clear description of the reasoning strategy
- Key steps to implement the approach
- Confidence score (0-100) based on:
  * Logical soundness
  * Coverage of options
  * Appropriate use of deductive/inductive reasoning
  * Clarity of reasoning steps

Output JSON format:
{{
    "methods": [
        {{
            "description": "Approach description",
            "steps": ["step1", "step2"],
            "score": 0-100,
            "score_reason": "Scoring justification"
        }}
    ]
}}"""

        required_keys = {"description", "steps", "score", "score_reason"}
        last_response = None
        for attempt in range(self.config.max_retries):
            try:
                last_response = await self.llm.generate(prompt, response_format="json_object")
                data = self._parse_llm_json(last_response)

                if not isinstance(data, dict) or "methods" not in data:
                    raise ValueError("Invalid structure: missing 'methods' key")

                methods = data["methods"]
                if not isinstance(methods, list):
                    raise ValueError("Invalid structure: 'methods' must be a list")
                if len(methods) < 2:
                    raise ValueError(f"Expected at least 2 methods, got {len(methods)}")

                for method in methods:
                    if not isinstance(method, dict) or not required_keys.issubset(method):
                        raise ValueError("Missing required keys in method")
                    if not isinstance(method["steps"], list):
                        raise ValueError("Steps must be a list")
                    # Scores are compared/sorted below, so they must be numeric.
                    method["score"] = self._coerce_score(method["score"])

                return sorted(methods, key=lambda m: -m["score"])

            except (json.JSONDecodeError, ValueError, KeyError) as e:
                print(f"Attempt {attempt + 1} failed: {str(e)}")
                if attempt == self.config.max_retries - 1:
                    print(f"Final failed response: {last_response}")

        return []

    async def _check_classification(self, method: str, question: str, options: Dict[str, str]) -> Dict[str, Any]:
        """Step 5: ask the LLM whether *method* needs a case split.

        The result always has a boolean ``need_classify`` and a list ``cases``.
        String booleans such as ``"false"`` are interpreted, not treated as
        truthy. On any failure the result is
        ``{"need_classify": False, "reason": "Analysis failed", "cases": []}``.
        """
        options_text = "\n".join(f"{k}. {v}" for k, v in options.items())

        prompt = self._EXPERT_PREAMBLE + f"""Determine if this solution approach requires case classification:

Solution Approach: {method}
Question: {question}
Options:
{options_text}

Consider:
1. Does the question contain multiple scenarios or cases?
2. Are there conditional statements that create distinct possibilities?
3. Do the options represent different logical paths?
4. Would different initial assumptions lead to different solutions?

If classification needed, provide:
- Comprehensive case descriptions
- Precise conditions for each case
- Expected outcomes

Output JSON format:
{{
    "need_classify": true/false,
    "reason": "Classification rationale",
    "cases": [
        {{
            "description": "Case description",
            "conditions": {{"parameter": "value_range"}}
        }}
    ]
}}"""

        fallback = {
            "need_classify": False,
            "reason": "Analysis failed",
            "cases": [],
        }
        # Best-effort: fall back to "no classification" on any LLM/parse error.
        try:
            response = await self.llm.generate(prompt, response_format="json_object")
            data = self._parse_llm_json(response)
        except Exception as e:
            print(f"Classification check failed: {str(e)}")
            return fallback
        if not isinstance(data, dict):
            return fallback

        need_classify = data.get("need_classify", False)
        if isinstance(need_classify, str):
            # bool("false") would be True, so interpret string booleans explicitly.
            need_classify = need_classify.strip().lower() == "true"
        data["need_classify"] = bool(need_classify)
        if not isinstance(data.get("cases"), list):
            data["cases"] = []
        return data

    async def _solve_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Step 7: solve one node with its method and conditions.

        On success, sets ``node.answer`` and ``node.state = "solved"`` and returns
        ``{"node_id", "response", "answer"}``. Returns None if no option letter
        could be extracted. LLM errors propagate to the calling workflow.
        """
        node = self.nodes[node_id]
        description = node.method.get("description", "")
        steps_text = ", ".join(str(step) for step in node.method.get("steps", []))

        context = f"Question: {node.question}\nOptions:\n"
        context += "".join(f"{opt}. {text}\n" for opt, text in node.options.items())
        context += f"\nSolution Approach: {description}\n"
        context += f"conditions: {json.dumps(node.conditions, indent=2)}\n"

        # `\\boxed` is escaped: in a non-raw string `\b` would be a backspace character.
        prompt = self._EXPERT_PREAMBLE + f"""Solve this question using the specified approach:

{context}

Reasoning Steps:
1. Strictly follow the provided approach: {description}
2. Execute each step: {steps_text}
3. Consider all conditions
4. Evaluate each option systematically
5. Provide clear justification for inclusion/exclusion
6. Select the best answer

Output Requirements:
- End your response with: "Final Answer: [OPTION]"
- Use \\boxed{{[OPTION]}} to denote your answer
- Your answer must be A, B, C, or D
"""

        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)

        if answer:
            node.answer = answer
            node.state = "solved"
            return {
                "node_id": node_id,
                "response": response,
                "answer": answer,
            }
        return None

    async def _aggregate_answers(self, solutions: List[Dict[str, Any]]) -> str:
        """Step 8: combine leaf solutions into one answer.

        * No solutions: ``"X"``.
        * All solutions agree: that answer, with no LLM call.
        * Otherwise the LLM synthesises an answer. If its reply contains no
          option letter, the most common leaf answer is used (the earliest one
          wins a tie).
        """
        if not solutions:
            return self._NO_ANSWER

        answers = [s["answer"] for s in solutions]
        if len(set(answers)) == 1:
            return answers[0]

        parts = []
        for i, sol in enumerate(solutions, start=1):
            node = self.nodes[sol["node_id"]]
            parts.append(
                f"\n\nSolution {i} (Node {sol['node_id']}):"
                f"\nApproach: {node.method['description']}"
                f"\nconditions: {json.dumps(node.conditions, indent=2)}"
                f"\nAnswer: {sol['answer']}"
                f"\nReasoning Excerpt:\n{sol['response']}..."
            )
        solutions_text = "".join(parts)

        # `\\boxed` is escaped: in a non-raw string `\b` would be a backspace character.
        prompt = self._EXPERT_PREAMBLE + f"""Synthesize these approaches:

{solutions_text}

Instructions:
1. Analyze all solutions and their approaches
2. Identify the most reliable reasoning
3. Verify consistency with conditions
4. Select the best overall answer
5. Output format: \\boxed{{[ANSWER]}}
"""

        response = await self.llm.generate(prompt)
        aggregated = self._extract_answer(response)
        if aggregated:
            return aggregated
        # Don't discard the solved leaves just because the synthesis reply is unparseable.
        return max(answers, key=answers.count)

    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        """Build a JSON-serializable record of *result*, its nodes and its verification."""
        serialized_nodes = {
            node_id: {
                "node_id": node.node_id,
                "question": node.question,
                "options": node.options,
                "method": node.method,
                "conditions": node.conditions,
                "answer": node.answer,
                "state": node.state,
                "score": node.score,
            }
            for node_id, node in self.nodes.items()
        }

        # A missing or None final answer is recorded as the "X" sentinel.
        selected_answer = result.get("final_answer") or self._NO_ANSWER
        correct_answer = str(problem.get("answerKey") or "").strip().upper()
        is_correct = self.verify_answer(problem, selected_answer)
        verification = {
            "is_correct": is_correct,
            "correct_answer": correct_answer,
            "given_answer": selected_answer,
        }
        return {
            "problem": problem,
            "result": {
                "final_answer": selected_answer,
                "correct_answer": correct_answer,
                "is_correct": is_correct,
                "nodes": serialized_nodes,
            },
            "verification": verification,
        }

    def _extract_answer(self, text: str) -> Optional[str]:
        """Extract the chosen option letter (``"A"``-``"D"``) from a model response.

        Tries ``_ANSWER_PATTERNS`` in order (boxed answer, then an "answer is/:"
        statement, then an explicitly marked option) and returns the last match
        of the first pattern that matches. Returns None if *text* is empty or
        no pattern matches.
        """
        if not isinstance(text, str) or not text:
            return None
        for pattern in self._ANSWER_PATTERNS:
            matches = pattern.findall(text)
            if matches:
                return matches[-1].upper()
        return None

    def verify_answer(self, problem: Dict[str, Any], selected_answer: str) -> bool:
        """Return True iff *selected_answer* matches ``problem["answerKey"]``.

        The comparison is case-insensitive and ignores surrounding whitespace.
        A missing answer key or a non-string answer is never correct.
        """
        correct_answer = str(problem.get("answerKey") or "").strip().upper()
        if not correct_answer or not isinstance(selected_answer, str):
            return False
        return selected_answer.strip().upper() == correct_answer